{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem: Quantize Your Language Model\n",
    "\n",
    "### Problem Statement\n",
    "Implement a **language model** using an LSTM and apply **dynamic quantization** to optimize it for inference. Dynamic quantization reduces the model size and enhances inference speed by quantizing the weights of the model.\n",
    "\n",
    "### Requirements\n",
    "\n",
    "1. **Define the Language Model**:\n",
    "   - **Purpose**: Build a simple language model that predicts the next token in a sequence.\n",
    "   - **Components**:\n",
    "     - **Embedding Layer**: Converts input tokens into dense vector representations.\n",
    "     - **LSTM Layer**: Processes the embedded sequence to capture temporal dependencies.\n",
    "     - **Fully Connected Layer**: Outputs predictions for the next token.\n",
    "     - **Softmax Layer**: Applies a probability distribution over the vocabulary for predictions.\n",
    "   - **Forward Pass**:\n",
    "     - Pass the input sequence through the embedding layer.\n",
    "     - Feed the embedded sequence into the LSTM.\n",
    "     - Use the final hidden state from the LSTM to make predictions via the fully connected layer.\n",
    "     - Apply the softmax function to obtain probabilities over the vocabulary.\n",
    "\n",
    "2. **Apply Dynamic Quantization**:\n",
    "   - Quantize the model dynamically\n",
    "   - Evaluate the quantized model's performance compared to the original model."
   ]
  },
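  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make *quantizing the weights* concrete, here is a minimal pure-Python sketch of symmetric int8 quantization with a single shared scale. The helper names `quantize_int8` and `dequantize` and the sample weights are illustrative assumptions, and PyTorch's quantized kernels may use a different scheme internally (for example an affine or per-channel one).\n",
    "\n",
    "```python\n",
    "def quantize_int8(weights):\n",
    "    '''Map floats to integer codes in [-127, 127] using one shared scale.'''\n",
    "    max_abs = max(abs(w) for w in weights)\n",
    "    scale = max_abs / 127 if max_abs > 0 else 1.0\n",
    "    codes = [max(-127, min(127, round(w / scale))) for w in weights]\n",
    "    return codes, scale\n",
    "\n",
    "\n",
    "def dequantize(codes, scale):\n",
    "    '''Recover approximate floats from the integer codes.'''\n",
    "    return [c * scale for c in codes]\n",
    "\n",
    "\n",
    "# Made-up weights, chosen only for illustration\n",
    "weights = [0.42, -1.27, 0.05, 0.9]\n",
    "codes, scale = quantize_int8(weights)\n",
    "restored = dequantize(codes, scale)\n",
    "max_error = max(abs(w - r) for w, r in zip(weights, restored))\n",
    "\n",
    "print(f'codes: {codes}, scale: {scale}, max rounding error: {max_error}')\n",
    "```\n",
    "\n",
    "In this scheme each restored weight differs from the original by at most half of `scale`. That rounding error is the price of storing one byte per weight instead of four."
   ]
  },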
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.quantization import quantize_dynamic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: Define a simple Language Model (an LSTM-based model)\n",
    "class LanguageModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):\n",
    "        super(LanguageModel, self).__init__()\n",
    "        ...\n",
    "\n",
    "    def forward(self, x):\n",
    "        ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create synthetic training data\n",
    "torch.manual_seed(42)\n",
    "vocab_size = 50\n",
    "seq_length = 10\n",
    "batch_size = 32\n",
    "X_train = torch.randint(0, vocab_size, (batch_size, seq_length))  # Random integer input\n",
    "y_train = torch.randint(0, vocab_size, (batch_size,))  # Random target words\n",
    "\n",
    "# Initialize the model, loss function, and optimizer\n",
    "embed_size = 64\n",
    "hidden_size = 128\n",
    "num_layers = 2\n",
    "model = LanguageModel(vocab_size, embed_size, hidden_size, num_layers)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
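  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`nn.CrossEntropyLoss` applies log-softmax to its input internally, so it should receive raw logits. The pure-Python sketch below shows what happens when already-normalized probabilities are passed instead, so that softmax is effectively applied twice. The helper names `softmax` and `cross_entropy_from_logits` and the logits are made up for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def softmax(scores):\n",
    "    '''Numerically stable softmax over a list of floats.'''\n",
    "    peak = max(scores)\n",
    "    exps = [math.exp(s - peak) for s in scores]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "def cross_entropy_from_logits(logits, target):\n",
    "    '''Negative log-probability of the target class; softmax is applied inside.'''\n",
    "    return -math.log(softmax(logits)[target])\n",
    "\n",
    "\n",
    "vocab_size = 50\n",
    "# Hypothetical logits from a confident model: class 7 scores far above the rest\n",
    "logits = [0.0] * vocab_size\n",
    "logits[7] = 10.0\n",
    "\n",
    "loss_on_logits = cross_entropy_from_logits(logits, target=7)\n",
    "# Passing probabilities where logits are expected applies softmax twice\n",
    "loss_on_probs = cross_entropy_from_logits(softmax(logits), target=7)\n",
    "uniform_loss = math.log(vocab_size)\n",
    "\n",
    "print(f'logits: {loss_on_logits:.4f}, probabilities: {loss_on_probs:.4f}, uniform: {uniform_loss:.4f}')\n",
    "```\n",
    "\n",
    "With 50 classes, the loss computed on probabilities cannot fall below ln(1 + 49/e), about 2.95, even for a perfectly confident prediction, whereas the loss on logits can approach 0. A uniform guess gives ln(50), about 3.91."
   ]
  },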
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/5] - Loss: 3.9118\n",
      "Epoch [2/5] - Loss: 3.9113\n",
      "Epoch [3/5] - Loss: 3.9108\n",
      "Epoch [4/5] - Loss: 3.9103\n",
      "Epoch [5/5] - Loss: 3.9097\n"
     ]
    }
   ],
   "source": [
    "# Training loop\n",
    "epochs = 5\n",
    "for epoch in range(epochs):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    output = model(X_train)\n",
    "    loss = criterion(output, y_train)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Log progress every epoch\n",
    "    print(f\"Epoch [{epoch + 1}/{epochs}] - Loss: {loss.item():.4f}\")\n",
    "\n",
    "# Now, we will quantize the model dynamically to reduce its size and improve inference speed\n",
    "# Quantization: Apply dynamic quantization to the language model\n",
    "quantized_model = quantize_dynamic(model, {nn.Linear, nn.LSTM}, dtype=torch.qint8)\n",
    "\n",
    "# Save the quantized model\n",
    "torch.save(quantized_model.state_dict(), \"quantized_language_model.pth\")\n"
   ]
  },
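  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The requirement to compare the quantized model with the original could be approached along the lines of the sketch below. It uses a small stack of `nn.Linear` layers as a hypothetical stand-in for the trained `LanguageModel`, and `serialized_size_bytes` is an illustrative helper, not part of PyTorch. It reports the serialized size of each model and how often both pick the same class.\n",
    "\n",
    "```python\n",
    "import io\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.quantization import quantize_dynamic\n",
    "\n",
    "\n",
    "def serialized_size_bytes(module):\n",
    "    '''Size of the module's state dict when serialized with torch.save.'''\n",
    "    buffer = io.BytesIO()\n",
    "    torch.save(module.state_dict(), buffer)\n",
    "    return buffer.getbuffer().nbytes\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Hypothetical stand-in; in the notebook, use the trained LanguageModel instead\n",
    "float_model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 50)).eval()\n",
    "int8_model = quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)\n",
    "\n",
    "inputs = torch.randn(16, 128)\n",
    "with torch.no_grad():\n",
    "    float_out = float_model(inputs)\n",
    "    int8_out = int8_model(inputs)\n",
    "\n",
    "float_size = serialized_size_bytes(float_model)\n",
    "int8_size = serialized_size_bytes(int8_model)\n",
    "# Fraction of inputs for which both models predict the same class\n",
    "agreement = (float_out.argmax(dim=1) == int8_out.argmax(dim=1)).float().mean().item()\n",
    "max_diff = (float_out - int8_out).abs().max().item()\n",
    "\n",
    "print(f'float32 size: {float_size} bytes, int8 size: {int8_size} bytes')\n",
    "print(f'argmax agreement: {agreement:.2%}, max abs output difference: {max_diff:.4f}')\n",
    "```\n",
    "\n",
    "For the notebook's model, the same comparison would pass `{nn.Linear, nn.LSTM}` to `quantize_dynamic` and feed token ids such as `X_train` to both models. Inference speed could be compared in the same way by timing repeated forward passes."
   ]
  },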
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the quantized model and test it\n",
    "quantized_model = LanguageModel(vocab_size, embed_size, hidden_size, num_layers)\n",
    "\n",
    "# Apply dynamic quantization on the model after defining it\n",
    "quantized_model = quantize_dynamic(quantized_model, {nn.Linear, nn.LSTM}, dtype=torch.qint8)\n",
    "\n",
    "quantized_model.load_state_dict(torch.load(\"quantized_language_model.pth\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction for input [[15, 28, 33, 19, 37, 24, 48, 42, 33, 35]]: 49\n"
     ]
    }
   ],
   "source": [
    "# Testing the quantized model on a sample input\n",
    "quantized_model.eval()\n",
    "test_input = torch.randint(0, vocab_size, (1, seq_length))\n",
    "with torch.no_grad():\n",
    "    prediction = quantized_model(test_input)\n",
    "    print(f\"Prediction for input {test_input.tolist()}: {prediction.argmax(dim=1).item()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
